import os

import torch
from modelscope import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM

import base_util

# 默认的聊天模型加载基类
# Base class for loading chat models.
class BaseChatModel():

    def __init__(self, model_name):
        """Record the model name and compute the local download directory.

        Args:
            model_name: repo id of the model, e.g. 'Org/model-name'.
        """
        # Build the cache path portably with os.path.join. The original
        # hand-written Windows path relied on invalid escape sequences
        # (\m, \C) that only survived via CPython's deprecated leniency,
        # and broke on non-Windows systems. The trailing "" keeps the
        # path ending in a separator, as the original string did.
        self.model_download_path = os.path.join(
            base_util.current_project_dir, "models", "ChatModels", ""
        )
        self.model_name = model_name
        # Populated by subclasses in load_model().
        self.model = None

    def load_model(self, model_name):
        # To be implemented by subclasses: download weights and build
        # tokenizer/model, storing them on self.
        pass

    def chat(self, msg, history):
        # To be implemented by subclasses: return (reply, updated_history).
        pass

#     Extend with more methods as needed. Eventually there may be one super-generic
#     class that integrates the whole pipeline, so .emb, .rerank and .search all
#     live in a single class.




class internlm(BaseChatModel):
    """Wrapper around the InternLM chat model, downloaded via ModelScope."""

    def __init__(self, model_name='Shanghai_AI_Laboratory/internlm-chat-20b'):
        super().__init__(model_name)
        self.load_model(model_name)

    def load_model(self, model_name):
        """Download the weights and build the tokenizer and model.

        Returns the loaded model (also stored on ``self.model``).
        """
        # Fetch the checkpoint from ModelScope into the shared cache directory.
        self.model_path = snapshot_download(model_name, cache_dir=self.model_download_path)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        # Load at reduced precision: without an explicit torch_dtype,
        # transformers loads float32 weights, which can exhaust GPU memory.
        causal_lm = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.model = causal_lm.eval()
        return self.model

    def chat(self, msg, history):
        """Send one message through the model's built-in chat helper.

        Returns (reply, updated_history).
        """
        reply, updated_history = self.model.chat(self.tokenizer, msg, history)
        return reply, updated_history

class chatGLM(BaseChatModel):
    # Placeholder for a ChatGLM chat-model loader; not implemented yet.

    pass